{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Load PyTorch model\n",
    "\n",
    "In this tutorial, you will learn how to load an existing PyTorch model and use it to run a prediction task.\n",
    "\n",
    "We will run inference the DJL way, following the [ResNet example](https://pytorch.org/hub/pytorch_vision_resnet/) from the official PyTorch website.\n",
    "\n",
    "\n",
    "## Preparation\n",
    "\n",
    "This tutorial requires the Java Kernel. For installation instructions, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.8.0\n",
    "%maven ai.djl.pytorch:pytorch-engine:0.8.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "%maven net.java.dev.jna:jna:5.3.0\n",
    "        \n",
    "// See https://github.com/awslabs/djl/blob/master/pytorch/pytorch-engine/README.md\n",
    "// for more PyTorch library selection options\n",
    "%maven ai.djl.pytorch:pytorch-native-auto:1.6.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.awt.image.*;\n",
    "import ai.djl.*;\n",
    "import ai.djl.inference.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.modality.cv.*;\n",
    "import ai.djl.modality.cv.util.*;\n",
    "import ai.djl.modality.cv.transform.*;\n",
    "import ai.djl.modality.cv.translator.*;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.translate.*;\n",
    "import ai.djl.training.util.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare your model\n",
    "\n",
    "This tutorial assumes that you have a TorchScript model.\n",
    "DJL only supports the TorchScript format for loading models from PyTorch, so other models will need to be [converted](https://github.com/awslabs/djl/blob/master/docs/pytorch/how_to_convert_your_model_to_torchscript.md).\n",
    "A TorchScript model includes the model structure and all of the parameters.\n",
    "\n",
    "We will be using a pre-trained `resnet18` model. First, use `DownloadUtils` to download the model file and save it in the `build/pytorch_models/resnet18` folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz\", \"build/pytorch_models/resnet18/resnet18.pt\", new ProgressBar());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For image classification, you also need `synset.txt`, which stores the class labels. Here we download the synset containing the ImageNet labels that resnet18 was originally trained on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt\", \"build/pytorch_models/resnet18/synset.txt\", new ProgressBar());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Create a Translator\n",
    "\n",
    "We will create a transformation pipeline that mirrors the transforms shown in the [PyTorch example](https://pytorch.org/hub/pytorch_vision_resnet/).\n",
    "```python\n",
    "...\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "...\n",
    "```\n",
    "\n",
    "Then, we will use this pipeline to create the [`Translator`](https://javadoc.io/static/ai.djl/api/0.8.0/index.html?ai/djl/translate/Translator.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pipeline pipeline = new Pipeline();\n",
    "pipeline.add(new Resize(256))\n",
    "        .add(new CenterCrop(224, 224))\n",
    "        .add(new ToTensor())\n",
    "        .add(new Normalize(\n",
    "            new float[] {0.485f, 0.456f, 0.406f},\n",
    "            new float[] {0.229f, 0.224f, 0.225f}));\n",
    "\n",
    "Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()\n",
    "                .setPipeline(pipeline)\n",
    "                .optApplySoftmax(true)\n",
    "                .build();"
   ]
  },
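  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the preprocessing arithmetic concrete, here is a minimal plain-Java sketch of what `ToTensor` followed by `Normalize` does to a single channel value. It assumes that `ToTensor` scales 8-bit values into the range [0, 1]; the class `NormalizeSketch` and its method are hypothetical names used only for illustration and are not part of DJL.\n",
    "\n",
    "```java\n",
    "public class NormalizeSketch {\n",
    "    // Per-channel ImageNet statistics, copied from the pipeline above.\n",
    "    static final float[] MEAN = {0.485f, 0.456f, 0.406f};\n",
    "    static final float[] STD = {0.229f, 0.224f, 0.225f};\n",
    "\n",
    "    // Maps one 8-bit channel value (0..255) to the value the model receives.\n",
    "    // channel: 0 = red, 1 = green, 2 = blue.\n",
    "    static float normalize(int pixel, int channel) {\n",
    "        float scaled = pixel / 255f; // assumed ToTensor step: scale to [0, 1]\n",
    "        return (scaled - MEAN[channel]) / STD[channel]; // Normalize step\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Under these assumptions, a red value of 255 maps to roughly 2.25 and a red value of 0 maps to roughly -2.12."
   ]
  },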
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Load your model\n",
    "\n",
    "Next, we will set the model zoo location to the `build/pytorch_models/resnet18` directory where we saved the model. You can also create your own [`Repository`](https://javadoc.io/static/ai.djl/repository/0.8.0/index.html?ai/djl/repository/Repository.html) to avoid manually managing files.\n",
    "\n",
    "Then, we add search criteria to find the resnet18 model and load it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Search for models in the build/pytorch_models/resnet18 folder\n",
    "System.setProperty(\"ai.djl.repository.zoo.location\", \"build/pytorch_models/resnet18\");\n",
    "\n",
    "Criteria<Image, Classifications> criteria = Criteria.builder()\n",
    "        .setTypes(Image.class, Classifications.class)\n",
    "        // Only search the local directory; the artifact id has the form\n",
    "        // \"ai.djl.localmodelzoo:{name of the model}\"\n",
    "        .optArtifactId(\"ai.djl.localmodelzoo:resnet18\")\n",
    "        .optTranslator(translator)\n",
    "        .optProgress(new ProgressBar()).build();\n",
    "\n",
    "ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Load image for classification\n",
    "\n",
    "We will use a sample dog image to run our prediction on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var img = ImageFactory.getInstance().fromUrl(\"https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg\");\n",
    "img.getWrappedImage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Run inference\n",
    "\n",
    "Lastly, we create a predictor from the model; it uses the translator attached through the criteria in Step 3. With the predictor in hand, we call its `predict` method on our test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Predictor<Image, Classifications> predictor = model.newPredictor();\n",
    "Classifications classifications = predictor.predict(img);\n",
    "\n",
    "classifications"
   ]
  },
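  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the translator was built with `optApplySoftmax(true)`, the raw model outputs (logits) are converted into probabilities that sum to 1 before they are reported. The following is a minimal plain-Java sketch of that conversion; `SoftmaxSketch` is a hypothetical class for illustration and is not DJL's own implementation.\n",
    "\n",
    "```java\n",
    "public class SoftmaxSketch {\n",
    "    // Converts raw scores into probabilities that sum to 1.\n",
    "    static double[] softmax(double[] logits) {\n",
    "        // Subtract the maximum first so Math.exp cannot overflow.\n",
    "        double max = Double.NEGATIVE_INFINITY;\n",
    "        for (double v : logits) {\n",
    "            max = Math.max(max, v);\n",
    "        }\n",
    "        double sum = 0.0;\n",
    "        double[] out = new double[logits.length];\n",
    "        for (int i = 0; i < logits.length; i++) {\n",
    "            out[i] = Math.exp(logits[i] - max);\n",
    "            sum += out[i];\n",
    "        }\n",
    "        // Divide by the total so the entries form a probability distribution.\n",
    "        for (int i = 0; i < out.length; i++) {\n",
    "            out[i] /= sum;\n",
    "        }\n",
    "        return out;\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Two equal logits yield a probability of 0.5 each, and a larger logit always receives a larger probability."
   ]
  },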
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "Now, you can load any TorchScript model and run inference with it.\n",
    "\n",
    "You might also want to check out [load_mxnet_model.ipynb](https://github.com/awslabs/djl/blob/master/jupyter/load_mxnet_model.ipynb), which demonstrates loading a local model directly instead of through the Model Zoo API.\n",
    "To optimize inference performance, see [how_to_optimize_inference_performance](https://github.com/awslabs/djl/blob/master/docs/pytorch/how_to_optimize_inference_performance.md)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "12.0.2+10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
